In [94]:
!pip install catalyst --upgrade
!pip install --user git+https://github.com/albu/albumentations@bdd6a4e
!pip install --user git+https://github.com/qubvel/segmentation_models.pytorch
Requirement already up-to-date: catalyst in c:\users\shane\lib\site-packages (20.1.3)
Requirement already satisfied, skipping upgrade: torch>=1.0.0 in c:\users\shane\lib\site-packages (from catalyst) (1.4.0)
Requirement already satisfied, skipping upgrade: tensorboard>=1.14.0 in c:\users\shane\lib\site-packages (from catalyst) (2.1.0)
Requirement already satisfied, skipping upgrade: tensorboardX in c:\users\shane\lib\site-packages (from catalyst) (2.0)
Requirement already satisfied, skipping upgrade: crc32c>=1.7 in c:\users\shane\lib\site-packages (from catalyst) (2.0)
Requirement already satisfied, skipping upgrade: safitty>=1.2.3 in c:\users\shane\lib\site-packages (from catalyst) (1.3)
Requirement already satisfied, skipping upgrade: matplotlib in c:\users\shane\lib\site-packages (from catalyst) (3.1.2)
Requirement already satisfied, skipping upgrade: tqdm>=4.33.0 in c:\users\shane\lib\site-packages (from catalyst) (4.42.0)
Requirement already satisfied, skipping upgrade: PyYAML in c:\users\shane\lib\site-packages (from catalyst) (5.3)
Requirement already satisfied, skipping upgrade: opencv-python in c:\users\shane\lib\site-packages (from catalyst) (4.1.2.30)
Requirement already satisfied, skipping upgrade: Pillow<7 in c:\users\shane\lib\site-packages (from catalyst) (6.2.2)
Requirement already satisfied, skipping upgrade: numpy>=1.16.4 in c:\users\shane\lib\site-packages (from catalyst) (1.18.1)
Requirement already satisfied, skipping upgrade: packaging in c:\users\shane\lib\site-packages (from catalyst) (20.1)
Requirement already satisfied, skipping upgrade: scikit-learn>=0.20 in c:\users\shane\lib\site-packages (from catalyst) (0.22.1)
Requirement already satisfied, skipping upgrade: torchvision>=0.2.1 in c:\users\shane\lib\site-packages (from catalyst) (0.5.0)
Requirement already satisfied, skipping upgrade: ipython in c:\users\shane\lib\site-packages (from catalyst) (7.11.1)
Requirement already satisfied, skipping upgrade: pandas>=0.22 in c:\users\shane\lib\site-packages (from catalyst) (1.0.0)
Requirement already satisfied, skipping upgrade: seaborn in c:\users\shane\lib\site-packages (from catalyst) (0.10.0)
Requirement already satisfied, skipping upgrade: imageio in c:\users\shane\lib\site-packages (from catalyst) (2.6.1)
Requirement already satisfied, skipping upgrade: GitPython>=2.1.11 in c:\users\shane\lib\site-packages (from catalyst) (3.0.5)
Requirement already satisfied, skipping upgrade: scikit-image>=0.14.2 in c:\users\shane\lib\site-packages (from catalyst) (0.16.2)
Requirement already satisfied, skipping upgrade: plotly>=4.1.0 in c:\users\shane\lib\site-packages (from catalyst) (4.5.0)
Requirement already satisfied, skipping upgrade: protobuf>=3.6.0 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (3.11.2)
Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (45.1.0)
Requirement already satisfied, skipping upgrade: google-auth<2,>=1.6.3 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (1.11.0)
Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (0.16.1)
Requirement already satisfied, skipping upgrade: grpcio>=1.24.3 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (1.26.0)
Requirement already satisfied, skipping upgrade: six>=1.10.0 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (1.14.0)
Requirement already satisfied, skipping upgrade: requests<3,>=2.21.0 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (2.22.0)
Requirement already satisfied, skipping upgrade: absl-py>=0.4 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (0.9.0)
Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= "3" in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (0.33.6)
Requirement already satisfied, skipping upgrade: google-auth-oauthlib<0.5,>=0.4.1 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (0.4.1)
Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in c:\users\shane\lib\site-packages (from tensorboard>=1.14.0->catalyst) (3.1.1)
Requirement already satisfied, skipping upgrade: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in c:\users\shane\lib\site-packages (from matplotlib->catalyst) (2.4.6)
Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in c:\users\shane\lib\site-packages (from matplotlib->catalyst) (1.1.0)
Requirement already satisfied, skipping upgrade: cycler>=0.10 in c:\users\shane\lib\site-packages (from matplotlib->catalyst) (0.10.0)
Requirement already satisfied, skipping upgrade: python-dateutil>=2.1 in c:\users\shane\lib\site-packages (from matplotlib->catalyst) (2.8.1)
Requirement already satisfied, skipping upgrade: scipy>=0.17.0 in c:\users\shane\lib\site-packages (from scikit-learn>=0.20->catalyst) (1.4.1)
Requirement already satisfied, skipping upgrade: joblib>=0.11 in c:\users\shane\lib\site-packages (from scikit-learn>=0.20->catalyst) (0.14.1)
Requirement already satisfied, skipping upgrade: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in c:\users\shane\lib\site-packages (from ipython->catalyst) (3.0.3)
Requirement already satisfied, skipping upgrade: decorator in c:\users\shane\lib\site-packages (from ipython->catalyst) (4.4.1)
Requirement already satisfied, skipping upgrade: colorama; sys_platform == "win32" in c:\users\shane\lib\site-packages (from ipython->catalyst) (0.4.3)
Requirement already satisfied, skipping upgrade: pickleshare in c:\users\shane\lib\site-packages (from ipython->catalyst) (0.7.5)
Requirement already satisfied, skipping upgrade: jedi>=0.10 in c:\users\shane\lib\site-packages (from ipython->catalyst) (0.16.0)
Requirement already satisfied, skipping upgrade: traitlets>=4.2 in c:\users\shane\lib\site-packages (from ipython->catalyst) (4.3.3)
Requirement already satisfied, skipping upgrade: backcall in c:\users\shane\lib\site-packages (from ipython->catalyst) (0.1.0)
Requirement already satisfied, skipping upgrade: pygments in c:\users\shane\lib\site-packages (from ipython->catalyst) (2.5.2)
Requirement already satisfied, skipping upgrade: pytz>=2017.2 in c:\users\shane\lib\site-packages (from pandas>=0.22->catalyst) (2019.3)
Requirement already satisfied, skipping upgrade: gitdb2>=2.0.0 in c:\users\shane\lib\site-packages (from GitPython>=2.1.11->catalyst) (2.0.6)
Requirement already satisfied, skipping upgrade: networkx>=2.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.14.2->catalyst) (2.4)
Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.14.2->catalyst) (1.1.1)
Requirement already satisfied, skipping upgrade: retrying>=1.3.3 in c:\users\shane\lib\site-packages (from plotly>=4.1.0->catalyst) (1.3.3)
Requirement already satisfied, skipping upgrade: pyasn1-modules>=0.2.1 in c:\users\shane\lib\site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14.0->catalyst) (0.2.8)
Requirement already satisfied, skipping upgrade: rsa<4.1,>=3.1.4 in c:\users\shane\lib\site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14.0->catalyst) (4.0)
Requirement already satisfied, skipping upgrade: cachetools<5.0,>=2.0.0 in c:\users\shane\lib\site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14.0->catalyst) (4.0.0)
Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in c:\users\shane\lib\site-packages (from requests<3,>=2.21.0->tensorboard>=1.14.0->catalyst) (2.7)
Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in c:\users\shane\lib\site-packages (from requests<3,>=2.21.0->tensorboard>=1.14.0->catalyst) (2018.1.18)
Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in c:\users\shane\lib\site-packages (from requests<3,>=2.21.0->tensorboard>=1.14.0->catalyst) (3.0.4)
Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in c:\users\shane\lib\site-packages (from requests<3,>=2.21.0->tensorboard>=1.14.0->catalyst) (1.24.3)
Requirement already satisfied, skipping upgrade: requests-oauthlib>=0.7.0 in c:\users\shane\lib\site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14.0->catalyst) (1.3.0)
Requirement already satisfied, skipping upgrade: wcwidth in c:\users\shane\lib\site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->catalyst) (0.1.8)
Requirement already satisfied, skipping upgrade: parso>=0.5.2 in c:\users\shane\lib\site-packages (from jedi>=0.10->ipython->catalyst) (0.6.0)
Requirement already satisfied, skipping upgrade: ipython-genutils in c:\users\shane\lib\site-packages (from traitlets>=4.2->ipython->catalyst) (0.2.0)
Requirement already satisfied, skipping upgrade: smmap2>=2.0.0 in c:\users\shane\lib\site-packages (from gitdb2>=2.0.0->GitPython>=2.1.11->catalyst) (2.0.5)
Requirement already satisfied, skipping upgrade: pyasn1<0.5.0,>=0.4.6 in c:\users\shane\lib\site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=1.14.0->catalyst) (0.4.8)
Requirement already satisfied, skipping upgrade: oauthlib>=3.0.0 in c:\users\shane\lib\site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14.0->catalyst) (3.1.0)
Collecting git+https://github.com/albu/albumentations@bdd6a4e
  Cloning https://github.com/albu/albumentations (to revision bdd6a4e) to c:\users\shane\appdata\local\temp\pip-req-build-4fbyt0fz
Requirement already satisfied: numpy>=1.11.1 in c:\users\shane\lib\site-packages (from albumentations==0.2.2) (1.18.1)
Requirement already satisfied: scipy in c:\users\shane\lib\site-packages (from albumentations==0.2.2) (1.4.1)
Collecting opencv-python-headless
  Using cached opencv_python_headless-4.1.2.30-cp36-cp36m-win_amd64.whl (33.0 MB)
Requirement already satisfied: imgaug<0.2.7,>=0.2.5 in c:\users\shane\lib\site-packages (from albumentations==0.2.2) (0.2.6)
Requirement already satisfied: six in c:\users\shane\lib\site-packages (from imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (1.14.0)
Requirement already satisfied: scikit-image>=0.11.0 in c:\users\shane\lib\site-packages (from imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (0.16.2)
Requirement already satisfied: networkx>=2.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (2.4)
Requirement already satisfied: imageio>=2.3.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (2.6.1)
Requirement already satisfied: PyWavelets>=0.4.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (1.1.1)
Requirement already satisfied: pillow>=4.3.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (6.2.2)
Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in c:\users\shane\lib\site-packages (from scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (3.1.2)
Requirement already satisfied: decorator>=4.3.0 in c:\users\shane\lib\site-packages (from networkx>=2.0->scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (4.4.1)
Requirement already satisfied: python-dateutil>=2.1 in c:\users\shane\lib\site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (2.8.1)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in c:\users\shane\lib\site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (2.4.6)
Requirement already satisfied: cycler>=0.10 in c:\users\shane\lib\site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in c:\users\shane\lib\site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (1.1.0)
Requirement already satisfied: setuptools in c:\users\shane\lib\site-packages (from kiwisolver>=1.0.1->matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.11.0->imgaug<0.2.7,>=0.2.5->albumentations==0.2.2) (45.1.0)
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py): started
  Building wheel for albumentations (setup.py): finished with status 'done'
  Created wheel for albumentations: filename=albumentations-0.2.2-py3-none-any.whl size=40824 sha256=b2e3ffaadcca719c262909ef28a7e8a0f248e338e42d26f0644e4a822a11d1d7
  Stored in directory: C:\Users\Shane\AppData\Local\Temp\pip-ephem-wheel-cache-borrdapx\wheels\31\7a\63\9e858e89b0e44cb4f3621b0ce0c077363fbe546b04b1dcc0ba
Successfully built albumentations
Installing collected packages: opencv-python-headless, albumentations
Successfully installed albumentations-0.2.2 opencv-python-headless-4.1.2.30
  Running command git clone -q https://github.com/albu/albumentations 'C:\Users\Shane\AppData\Local\Temp\pip-req-build-4fbyt0fz'
  WARNING: Did not find branch or tag 'bdd6a4e', assuming revision or ref.
  Running command git checkout -q bdd6a4e
Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to c:\users\shane\appdata\local\temp\pip-req-build-81_4swz8
Requirement already satisfied (use --upgrade to upgrade): segmentation-models-pytorch==0.1.0 from git+https://github.com/qubvel/segmentation_models.pytorch in c:\users\shane\lib\site-packages
Requirement already satisfied: torchvision>=0.3.0 in c:\users\shane\lib\site-packages (from segmentation-models-pytorch==0.1.0) (0.5.0)
Requirement already satisfied: pretrainedmodels==0.7.4 in c:\users\shane\lib\site-packages (from segmentation-models-pytorch==0.1.0) (0.7.4)
Requirement already satisfied: efficientnet-pytorch>=0.5.1 in c:\users\shane\lib\site-packages (from segmentation-models-pytorch==0.1.0) (0.6.1)
Requirement already satisfied: pillow>=4.1.1 in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (6.2.2)
Requirement already satisfied: numpy in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (1.18.1)
Requirement already satisfied: six in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (1.14.0)
Requirement already satisfied: torch==1.4.0 in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (1.4.0)
Requirement already satisfied: tqdm in c:\users\shane\lib\site-packages (from pretrainedmodels==0.7.4->segmentation-models-pytorch==0.1.0) (4.42.0)
Requirement already satisfied: munch in c:\users\shane\lib\site-packages (from pretrainedmodels==0.7.4->segmentation-models-pytorch==0.1.0) (2.5.0)
Building wheels for collected packages: segmentation-models-pytorch
  Building wheel for segmentation-models-pytorch (setup.py): started
  Building wheel for segmentation-models-pytorch (setup.py): finished with status 'done'
  Created wheel for segmentation-models-pytorch: filename=segmentation_models_pytorch-0.1.0-py3-none-any.whl size=47303 sha256=9ac583953fc7e0a1fbecd2fa936f171fe00869e093e4d8becd6008405cc741f0
  Stored in directory: C:\Users\Shane\AppData\Local\Temp\pip-ephem-wheel-cache-nwv_e3p5\wheels\53\e5\fc\18292d80d3c0f4efc96cbbb72625fdbafdca303997bacfb085
Successfully built segmentation-models-pytorch
  Running command git clone -q https://github.com/qubvel/segmentation_models.pytorch 'C:\Users\Shane\AppData\Local\Temp\pip-req-build-81_4swz8'
In [96]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch
Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to c:\users\shane\appdata\local\temp\pip-req-build-ycb81jmm
Requirement already satisfied (use --upgrade to upgrade): segmentation-models-pytorch==0.1.0 from git+https://github.com/qubvel/segmentation_models.pytorch in c:\users\shane\lib\site-packages
Requirement already satisfied: torchvision>=0.3.0 in c:\users\shane\lib\site-packages (from segmentation-models-pytorch==0.1.0) (0.5.0)
Requirement already satisfied: pretrainedmodels==0.7.4 in c:\users\shane\lib\site-packages (from segmentation-models-pytorch==0.1.0) (0.7.4)
Requirement already satisfied: efficientnet-pytorch>=0.5.1 in c:\users\shane\lib\site-packages (from segmentation-models-pytorch==0.1.0) (0.6.1)
Requirement already satisfied: torch==1.4.0 in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (1.4.0)
Requirement already satisfied: numpy in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (1.18.1)
Requirement already satisfied: six in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (1.14.0)
Requirement already satisfied: pillow>=4.1.1 in c:\users\shane\lib\site-packages (from torchvision>=0.3.0->segmentation-models-pytorch==0.1.0) (6.2.2)
Requirement already satisfied: tqdm in c:\users\shane\lib\site-packages (from pretrainedmodels==0.7.4->segmentation-models-pytorch==0.1.0) (4.42.0)
Requirement already satisfied: munch in c:\users\shane\lib\site-packages (from pretrainedmodels==0.7.4->segmentation-models-pytorch==0.1.0) (2.5.0)
Building wheels for collected packages: segmentation-models-pytorch
  Building wheel for segmentation-models-pytorch (setup.py): started
  Building wheel for segmentation-models-pytorch (setup.py): finished with status 'done'
  Created wheel for segmentation-models-pytorch: filename=segmentation_models_pytorch-0.1.0-py3-none-any.whl size=47303 sha256=56ddfc20cd708a02e4574f9f1b0631132032479e178a18bdcb8d2297455726de
  Stored in directory: C:\Users\Shane\AppData\Local\Temp\pip-ephem-wheel-cache-jjy8xgvx\wheels\53\e5\fc\18292d80d3c0f4efc96cbbb72625fdbafdca303997bacfb085
Successfully built segmentation-models-pytorch
  Running command git clone -q https://github.com/qubvel/segmentation_models.pytorch 'C:\Users\Shane\AppData\Local\Temp\pip-req-build-ycb81jmm'
In [1]:
import os
import cv2
import collections
import time 
import tqdm
from PIL import Image
from functools import partial
train_on_gpu = True
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
In [90]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR



import segmentation_models_pytorch as smp
model = smp.Unet()
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\Shane/.cache\torch\checkpoints\resnet34-333f7ec4.pth

In [68]:
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
from albumentations.pytorch import ToTensor
import albumentations as albu
from albumentations import pytorch as AT
In [148]:
from catalyst.dl.utils import criterion
In [151]:
from catalyst.data import Augmentor
from catalyst.dl import utils
from catalyst.data.reader import ImageReader, ScalarReader, ReaderCompose, LambdaReader
from catalyst.dl.runner import SupervisedRunner
#from catalyst.contrib.models.segmentation import Unet
from catalyst.dl.callbacks import DiceCallback, EarlyStoppingCallback, InferCallback, CheckpointCallback
In [44]:
def get_img(x, folder: str='train_images'):
    """
    Return image based on image name and folder.
    """
    data_folder = f"{path}/{folder}"
    image_path = os.path.join(data_folder, x)
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def rle_decode(mask_rle: str = '', shape: tuple = (1400, 2100)):
    '''
    Decode rle encoded mask.
    
    :param mask_rle: run-length as string formatted (start length)
    :param shape: (height, width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')


def make_mask(df: pd.DataFrame, image_name: str='img.jpg', shape: tuple = (1400, 2100)):
    """
    Create mask based on df, image name and shape.
    """
    encoded_masks = df.loc[df['im_id'] == image_name, 'EncodedPixels']
    masks = np.zeros((shape[0], shape[1], 4), dtype=np.float32)

    for idx, label in enumerate(encoded_masks.values):
        if label is not np.nan:
            mask = rle_decode(label)
            masks[:, :, idx] = mask
            
    return masks


def to_tensor(x, **kwargs):
    """
    Convert image or mask.
    """
    return x.transpose(2, 0, 1).astype('float32')


def mask2rle(img):
    '''
    Convert mask to rle.
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels= img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def visualize(image, mask, original_image=None, original_mask=None):
    """
    Plot image and masks.
    If two pairs of images and masks are passes, show both.
    """
    fontsize = 14
    class_dict = {0: 'Fish', 1: 'Flower', 2: 'Gravel', 3: 'Sugar'}
    
    if original_image is None and original_mask is None:
        f, ax = plt.subplots(1, 5, figsize=(24, 24))

        ax[0].imshow(image)
        for i in range(4):
            ax[i + 1].imshow(mask[:, :, i])
            ax[i + 1].set_title(f'Mask {class_dict[i]}', fontsize=fontsize)
    else:
        f, ax = plt.subplots(2, 5, figsize=(24, 12))

        ax[0, 0].imshow(original_image)
        ax[0, 0].set_title('Original image', fontsize=fontsize)
                
        for i in range(4):
            ax[0, i + 1].imshow(original_mask[:, :, i])
            ax[0, i + 1].set_title(f'Original mask {class_dict[i]}', fontsize=fontsize)
        
        ax[1, 0].imshow(image)
        ax[1, 0].set_title('Transformed image', fontsize=fontsize)
        
        
        for i in range(4):
            ax[1, i + 1].imshow(mask[:, :, i])
            ax[1, i + 1].set_title(f'Transformed mask {class_dict[i]}', fontsize=fontsize)
            
            
def visualize_with_raw(image, mask, original_image=None, original_mask=None, raw_image=None, raw_mask=None):
    """
    Plot image and masks.
    If two pairs of images and masks are passes, show both.
    """
    fontsize = 14
    class_dict = {0: 'Fish', 1: 'Flower', 2: 'Gravel', 3: 'Sugar'}

    f, ax = plt.subplots(3, 5, figsize=(24, 12))

    ax[0, 0].imshow(original_image)
    ax[0, 0].set_title('Original image', fontsize=fontsize)

    for i in range(4):
        ax[0, i + 1].imshow(original_mask[:, :, i])
        ax[0, i + 1].set_title(f'Original mask {class_dict[i]}', fontsize=fontsize)


    ax[1, 0].imshow(raw_image)
    ax[1, 0].set_title('Original image', fontsize=fontsize)

    for i in range(4):
        ax[1, i + 1].imshow(raw_mask[:, :, i])
        ax[1, i + 1].set_title(f'Raw predicted mask {class_dict[i]}', fontsize=fontsize)
        
    ax[2, 0].imshow(image)
    ax[2, 0].set_title('Transformed image', fontsize=fontsize)


    for i in range(4):
        ax[2, i + 1].imshow(mask[:, :, i])
        ax[2, i + 1].set_title(f'Predicted mask with processing {class_dict[i]}', fontsize=fontsize)
            
            
def plot_with_augmentation(image, mask, augment):
    """
    Wrapper for `visualize` function.
    """
    augmented = augment(image=image, mask=mask)
    image_flipped = augmented['image']
    mask_flipped = augmented['mask']
    visualize(image_flipped, mask_flipped, original_image=image, original_mask=mask)
    
    
sigmoid = lambda x: 1 / (1 + np.exp(-x))


def post_process(probability, threshold, min_size):
    """
    Post processing of each predicted mask, components with lesser number of pixels
    than `min_size` are ignored
    """
    # don't remember where I saw it
    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    predictions = np.zeros((350, 525), np.float32)
    num = 0
    for c in range(1, num_component):
        p = (component == c)
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num


def get_training_augmentation():
    train_transform = [

        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=0.5, border_mode=0),
        albu.GridDistortion(p=0.5),
        albu.OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5),
        albu.Resize(320, 640)
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [
        albu.Resize(320, 640)
    ]
    return albu.Compose(test_transform)


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)


def dice(img1, img2):
    img1 = np.asarray(img1).astype(np.bool)
    img2 = np.asarray(img2).astype(np.bool)

    intersection = np.logical_and(img1, img2)

    return 2. * intersection.sum() / (img1.sum() + img2.sum())
In [45]:
path = '../cosmology'
os.listdir(path)
Out[45]:
['.ipynb_checkpoints',
 'BTC Mk1.ipynb',
 'crypto_tradinds.csv',
 'cumulative.csv',
 'exoTest.csv',
 'exoTrain.csv',
 'historical-data-on-the-trading-of-cryptocurrencies.zip',
 'New folder',
 'sample_submission.csv',
 'test_images',
 'train.csv',
 'train_images',
 'Untitled.ipynb',
 'Untitled1.ipynb']
In [46]:
train = pd.read_csv(f'{path}/train.csv')
sub = pd.read_csv(f'{path}/sample_submission.csv')
In [47]:
train.head()
Out[47]:
Image_Label EncodedPixels
0 0011165.jpg_Fish 264918 937 266318 937 267718 937 269118 937 27...
1 0011165.jpg_Flower 1355565 1002 1356965 1002 1358365 1002 1359765...
2 0011165.jpg_Gravel NaN
3 0011165.jpg_Sugar NaN
4 002be4f.jpg_Fish 233813 878 235213 878 236613 878 238010 881 23...
In [48]:
n_train = len(os.listdir(f'{path}/train_images'))
n_test = len(os.listdir(f'{path}/test_images'))
print(f'There are {n_train} images in train dataset')
print(f'There are {n_test} images in test dataset')
There are 5546 images in train dataset
There are 3698 images in test dataset
In [49]:
train['Image_Label'].apply(lambda x: x.split('_')[1]).value_counts()
Out[49]:
Sugar     5546
Flower    5546
Gravel    5546
Fish      5546
Name: Image_Label, dtype: int64
In [50]:
train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(lambda x: x.split('_')[1]).value_counts()
Out[50]:
Sugar     3751
Gravel    2939
Fish      2781
Flower    2365
Name: Image_Label, dtype: int64
In [51]:
train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(lambda x: x.split('_')[0]).value_counts().value_counts()
Out[51]:
2    2372
3    1560
1    1348
4     266
Name: Image_Label, dtype: int64
In [52]:
train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])


sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])
In [53]:
fig = plt.figure(figsize=(25, 16))
for j, im_id in enumerate(np.random.choice(train['im_id'].unique(), 4)):
    for i, (idx, row) in enumerate(train.loc[train['im_id'] == im_id].iterrows()):
        ax = fig.add_subplot(5, 4, j * 4 + i + 1, xticks=[], yticks=[])
        im = Image.open(f"{path}/train_images/{row['Image_Label'].split('_')[0]}")
        plt.imshow(im)
        mask_rle = row['EncodedPixels']
        try: # label might not be there!
            mask = rle_decode(mask_rle)
        except:
            mask = np.zeros((1400, 2100))
        plt.imshow(mask, alpha=0.5, cmap='gray')
        ax.set_title(f"Image: {row['Image_Label'].split('_')[0]}. Label: {row['label']}")
In [54]:
id_mask_count = train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(lambda x: x.split('_')[0]).value_counts().\
reset_index().rename(columns={'index': 'img_id', 'Image_Label': 'count'})
train_ids, valid_ids = train_test_split(id_mask_count['img_id'].values, random_state=42, stratify=id_mask_count['count'], test_size=0.1)
test_ids = sub['Image_Label'].apply(lambda x: x.split('_')[0]).drop_duplicates().values
In [55]:
image_name = '8242ba0.jpg'
image = get_img(image_name)
mask = make_mask(train, image_name)
In [56]:
visualize(image, mask)
In [57]:
plot_with_augmentation(image, mask, albu.HorizontalFlip(p=1))
In [58]:
plot_with_augmentation(image, mask, albu.VerticalFlip(p=1))
In [59]:
plot_with_augmentation(image, mask, albu.RandomRotate90(p=1))
In [60]:
plot_with_augmentation(image, mask, albu.ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03))
In [63]:
plot_with_augmentation(image, mask, albu.GridDistortion(p=0.1))
In [64]:
plot_with_augmentation(image, mask, albu.OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5))
In [71]:
class CloudDataset(Dataset):
    def __init__(self, df: pd.DataFrame = None, datatype: str = 'train', img_ids: np.array = None,
                 transforms = albu.Compose([albu.HorizontalFlip(),AT.ToTensor()]),
                preprocessing=None):
        self.df = df
        if datatype != 'test':
            self.data_folder = f"{path}/train_images"
        else:
            self.data_folder = f"{path}/test_images"
        self.img_ids = img_ids
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __getitem__(self, idx):
        image_name = self.img_ids[idx]
        mask = make_mask(self.df, image_name)
        image_path = os.path.join(self.data_folder, image_name)
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        augmented = self.transforms(image=img, mask=mask)
        img = augmented['image']
        mask = augmented['mask']
        if self.preprocessing:
            preprocessed = self.preprocessing(image=img, mask=mask)
            img = preprocessed['image']
            mask = preprocessed['mask']
        return img, mask

    def __len__(self):
        return len(self.img_ids)
In [72]:
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
DEVICE = 'cuda'

ACTIVATION = None
model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=4, 
    activation=ACTIVATION,
)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
In [144]:
num_workers = 0
bs = 16
train_dataset = CloudDataset(df=train, datatype='train', img_ids=train_ids, transforms = get_training_augmentation(), preprocessing=get_preprocessing(preprocessing_fn))
valid_dataset = CloudDataset(df=train, datatype='valid', img_ids=valid_ids, transforms = get_validation_augmentation(), preprocessing=get_preprocessing(preprocessing_fn))

train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=num_workers)

loaders = {
    "train": train_loader,
    "valid": valid_loader
}
In [145]:
 
In [158]:
num_epochs = 19
logdir = "./logs/segmentation"

# model, criterion, optimizer
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-2}, 
    {'params': model.encoder.parameters(), 'lr': 1e-3},  
])

scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
#criterion = smp.utils.losses.BCELoss(None)
criterion = nn.CrossEntropyLoss()
runner = SupervisedRunner()
In [159]:
runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=[DiceCallback(), EarlyStoppingCallback(patience=5, min_delta=0.001)],
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True
)






















1/19 * Epoch (train):   0% 0/312 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-159-20d5dcd769a5> in <module>
      9     logdir=logdir,
     10     num_epochs=num_epochs,
---> 11     verbose=True
     12 )

c:\users\shane\lib\site-packages\catalyst\dl\runner\supervised.py in train(self, model, criterion, optimizer, loaders, logdir, callbacks, scheduler, resume, num_epochs, valid_loader, main_metric, minimize_metric, verbose, state_kwargs, checkpoint_data, fp16, monitoring_params, check)
    204             monitoring_params=monitoring_params
    205         )
--> 206         self.run_experiment(experiment, check=check)
    207 
    208     def infer(

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    380             else:
    381                 self.state.exception = ex
--> 382                 self._run_event("exception", moment=None)
    383 
    384         return self

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_event(self, event, moment)
    229                 (moment == "end" or moment is None):  # for on_exception case
    230             for logger in self.loggers.values():
--> 231                 getattr(logger, fn_name)(self.state)
    232 
    233         if self.state is not None:

c:\users\shane\lib\site-packages\catalyst\dl\callbacks\misc.py in on_exception(self, state)
    150 
    151         if state.need_reraise_exception:
--> 152             raise exception
    153 
    154 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    372         try:
    373             for stage in self.experiment.stages:
--> 374                 self._run_stage(stage)
    375         except (Exception, KeyboardInterrupt) as ex:
    376             # if an exception had been raised

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_stage(self, stage)
    341 
    342             self._run_event("epoch", moment="start")
--> 343             self._run_epoch(stage=stage, epoch=epoch)
    344             self._run_event("epoch", moment="end")
    345 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_epoch(self, stage, epoch)
    330             self._run_event("loader", moment="start")
    331             with torch.set_grad_enabled(self.state.need_backward):
--> 332                 self._run_loader(loader)
    333             self._run_event("loader", moment="end")
    334 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_loader(self, loader)
    290         self.state.timer.start("_timers/data_time")
    291 
--> 292         for i, batch in enumerate(loader):
    293             self._run_batch(batch)
    294 

c:\users\shane\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

c:\users\shane\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

c:\users\shane\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

c:\users\shane\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-71-29d9038310f8> in __getitem__(self, idx)
     18         img = cv2.imread(image_path)
     19         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
---> 20         augmented = self.transforms(image=img, mask=mask)
     21         img = augmented['image']
     22         mask = augmented['mask']

c:\users\shane\lib\site-packages\albumentations\core\composition.py in __call__(self, force_apply, **data)
    174                     p.preprocess(data)
    175 
--> 176             data = t(force_apply=force_apply, **data)
    177 
    178             if dual_start_end is not None and idx == dual_start_end[1]:

c:\users\shane\lib\site-packages\albumentations\core\transforms_interface.py in __call__(self, force_apply, **kwargs)
     85                     )
     86                 kwargs[self.save_key][id(self)] = deepcopy(params)
---> 87             return self.apply_with_params(params, **kwargs)
     88 
     89         return kwargs

c:\users\shane\lib\site-packages\albumentations\core\transforms_interface.py in apply_with_params(self, params, force_apply, **kwargs)
     98                 target_function = self._get_target_function(key)
     99                 target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
--> 100                 res[key] = target_function(arg, **dict(params, **target_dependencies))
    101             else:
    102                 res[key] = None

c:\users\shane\lib\site-packages\albumentations\augmentations\transforms.py in apply(self, img, stepsx, stepsy, interpolation, **params)
   1218 
   1219     def apply(self, img, stepsx=[], stepsy=[], interpolation=cv2.INTER_LINEAR, **params):
-> 1220         return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value)
   1221 
   1222     def apply_to_mask(self, img, stepsx=[], stepsy=[], **params):

c:\users\shane\lib\site-packages\albumentations\augmentations\functional.py in wrapped_function(img, *args, **kwargs)
     52     def wrapped_function(img, *args, **kwargs):
     53         shape = img.shape
---> 54         result = func(img, *args, **kwargs)
     55         result = result.reshape(shape)
     56         return result

c:\users\shane\lib\site-packages\albumentations\augmentations\functional.py in grid_distortion(img, num_steps, xsteps, ysteps, interpolation, border_mode, value)
   1079             cur = prev + y_step * ysteps[idx]
   1080 
-> 1081         yy[start:end] = np.linspace(prev, cur, end - start)
   1082         prev = cur
   1083 

<__array_function__ internals> in linspace(*args, **kwargs)

c:\users\shane\lib\site-packages\numpy\core\function_base.py in linspace(start, stop, num, endpoint, retstep, dtype, axis)
    122 
    123     if num < 0:
--> 124         raise ValueError("Number of samples, %s, must be non-negative." % num)
    125     div = (num - 1) if endpoint else num
    126 

ValueError: Number of samples, -175, must be non-negative.
In [88]:
utils.plot_metrics(
    logdir=logdir, 
    # specify which metrics we want to plot
    metrics=["loss", "dice", 'lr', '_base/lr']
)
In [160]:
encoded_pixels = []
loaders = {"infer": valid_loader}
runner.infer(
    model=model,
    loaders=loaders,
    callbacks=[
        CheckpointCallback(
            resume=f"{logdir}/checkpoints/best.pth"),
        InferCallback()
    ],
)
valid_masks = []
probabilities = np.zeros((2220, 350, 525))
for i, (batch, output) in enumerate(tqdm.tqdm(zip(
        valid_dataset, runner.callbacks[0].predictions["logits"]))):
    image, mask = batch
    for m in mask:
        if m.shape != (350, 525):
            m = cv2.resize(m, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
        valid_masks.append(m)

    for j, probability in enumerate(output):
        if probability.shape != (350, 525):
            probability = cv2.resize(probability, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
        probabilities[i * 4 + j, :, :] = probability
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-160-f2477811b7df> in <module>
      7         CheckpointCallback(
      8             resume=f"{logdir}/checkpoints/best.pth"),
----> 9         InferCallback()
     10     ],
     11 )

c:\users\shane\lib\site-packages\catalyst\dl\runner\supervised.py in infer(self, model, loaders, callbacks, verbose, state_kwargs, fp16, check)
    248             distributed_params=fp16
    249         )
--> 250         self.run_experiment(experiment, check=check)
    251 
    252     def predict_loader(

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    380             else:
    381                 self.state.exception = ex
--> 382                 self._run_event("exception", moment=None)
    383 
    384         return self

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_event(self, event, moment)
    229                 (moment == "end" or moment is None):  # for on_exception case
    230             for logger in self.loggers.values():
--> 231                 getattr(logger, fn_name)(self.state)
    232 
    233         if self.state is not None:

c:\users\shane\lib\site-packages\catalyst\dl\callbacks\misc.py in on_exception(self, state)
    150 
    151         if state.need_reraise_exception:
--> 152             raise exception
    153 
    154 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    372         try:
    373             for stage in self.experiment.stages:
--> 374                 self._run_stage(stage)
    375         except (Exception, KeyboardInterrupt) as ex:
    376             # if an exception had been raised

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_stage(self, stage)
    336         self._prepare_for_stage(stage)
    337 
--> 338         self._run_event("stage", moment="start")
    339         for epoch in range(self.state.num_epochs):
    340             self.state.stage_epoch = epoch

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_event(self, event, moment)
    223         if self.callbacks is not None:
    224             for callback in self.callbacks.values():
--> 225                 getattr(callback, fn_name)(self.state)
    226 
    227         # after callbacks

c:\users\shane\lib\site-packages\catalyst\core\callbacks\checkpoint.py in on_stage_start(self, state)
    212 
    213         if self.resume is not None:
--> 214             self.load_checkpoint(filename=self.resume, state=state)
    215 
    216     def on_epoch_end(self, state: _State):

c:\users\shane\lib\site-packages\catalyst\core\callbacks\checkpoint.py in load_checkpoint(filename, state)
    125             )
    126         else:
--> 127             raise Exception(f"No checkpoint found at {filename}")
    128 
    129     def get_metric(self, last_valid_metrics) -> Dict:

Exception: No checkpoint found at ./logs/segmentation/checkpoints/best.pth
In [99]:
torch.cuda.is_available 
Out[99]:
<function torch.cuda.is_available()>
In [100]:
x.cuda()      
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-100-4da82bb84806> in <module>
----> 1 x.cuda()

NameError: name 'x' is not defined
In [ ]: